{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FAIRSeq in Amazon SageMaker: Translation task - German to English - Distributed / multi machine training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Facebook AI Research (FAIR) Lab made available through the [FAIRSeq toolkit](https://github.com/pytorch/fairseq) their state-of-the-art Sequence to Sequence models. \n",
    "\n",
    "In this notebook, we will show you how to train a German to English translation model using a fully convolutional architecture on multiple GPUs and machines.\n",
    "\n",
    "## Permissions\n",
    "\n",
    "Running this notebook requires permissions in addition to the regular SageMakerFullAccess permissions. This is because it creates new repositories in Amazon ECR. The easiest way to add these permissions is simply to add the managed policy AmazonEC2ContainerRegistryFullAccess to the role that you used to start your notebook instance. There's no need to restart your notebook instance when you do this, the new permissions will be available immediately."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare dataset\n",
    "\n",
    "To train the model, we will be using the IWSLT'14 dataset as descibed [here](https://github.com/pytorch/fairseq/tree/master/examples/translation#prepare-iwslt14sh). This was used in the IWSLT'14 German to English translation task: [\"Report on the 11th IWSLT evaluation campaign\" by Cettolo et al](http://workshop2014.iwslt.org/downloads/proceeding.pdf).\n",
    "\n",
    "First, we'll download the dataset and start the pre-processing. Among other steps, this pre-processing cleans the tokens and applys BPE encoding as you can see [here](https://github.com/pytorch/fairseq/blob/master/examples/translation/prepare-iwslt14.sh)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sh\n",
    "cd data\n",
    "chmod +x prepare-iwslt14.sh\n",
    "\n",
    "# Download dataset and start pre-processing\n",
    "./prepare-iwslt14.sh"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next step is to apply the second set of pre-processing, which binarizes the dataset based on the source and target language. Full information on this script [here](https://github.com/pytorch/fairseq/blob/master/preprocess.py).  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%sh\n",
    "\n",
    "# First we download fairseq in order to have access to the scripts\n",
    "git clone https://github.com/pytorch/fairseq.git fairseq-git\n",
    "cd fairseq-git\n",
    "\n",
    "# Binarize the dataset:\n",
    "TEXT=../data/iwslt14.tokenized.de-en\n",
    "python preprocess.py --source-lang de --target-lang en \\\n",
    "  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \\\n",
    "  --destdir ..data/iwslt14.tokenized.de-en"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataset is now all prepared for training on one of the FAIRSeq translation models. The next step is upload the data to Amazon S3 in order to make it available for training."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Upload data to Amazon S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "\n",
    "sagemaker_session = sagemaker.Session()\n",
    "region =  sagemaker_session.boto_session.region_name\n",
    "account = sagemaker_session.boto_session.client('sts').get_caller_identity().get('Account')\n",
    "\n",
    "bucket = sagemaker_session.default_bucket()\n",
    "prefix = 'sagemaker/DEMO-pytorch-fairseq/datasets/iwslt14'\n",
    "\n",
    "role = sagemaker.get_execution_role()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = sagemaker_session.upload_data(path='data/iwslt14.tokenized.de-en', bucket=bucket, key_prefix=prefix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we need to register a Docker image in Amazon SageMaker that will contain the FAIRSeq code and that will be pulled at training and inference time to perform the respective training of the model and the serving of the precitions. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build FAIRSeq Translation task container"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%sh\n",
    "chmod +x create_container.sh \n",
    "\n",
    "./create_container.sh pytorch-fairseq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The FAIRSeq image has been pushed into Amazon ECR, the registry from which Amazon SageMaker will be able to pull that image and launch both training and prediction. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training on Amazon SageMaker\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we will set the hyper-parameters of the model we want to train. Here we are using the recommended ones from the [FAIRSeq example](https://github.com/pytorch/fairseq/tree/master/examples/translation#prepare-iwslt14sh). The full list of hyper-parameters available for use can be found [here](https://fairseq.readthedocs.io/en/latest/command_line_tools.html). Please note you can use dataset, training, and generation parameters. For the distributed backend, **gloo** is the only supported option and is set as default. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparameters = {\n",
    "    \"lr\": 0.25,    \n",
    "    \"clip-norm\": 0.1,\n",
    "    \"dropout\": 0.2,\n",
    "    \"max-tokens\": 4000,\n",
    "    \"criterion\": \"label_smoothed_cross_entropy\",\n",
    "    \"label-smoothing\": 0.1,\n",
    "    \"lr-scheduler\": \"fixed\",\n",
    "    \"force-anneal\": 200,\n",
    "    \"arch\": \"fconv_iwslt_de_en\"\n",
    "}"
   ]
  },
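  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The names above follow the FAIRSeq command-line options linked earlier. As a rough sketch of how such a dictionary could be flattened into command-line arguments, the cell below uses a hypothetical helper, `to_cli_args`. It is an illustration only and runs locally; how the container built above actually passes hyper-parameters to FAIRSeq is not shown in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper, for illustration only: not part of FAIRSeq or the SageMaker SDK.\n",
    "def to_cli_args(params):\n",
    "    # Turn each key/value pair into a '--key' flag followed by its value as a string\n",
    "    args = []\n",
    "    for name, value in params.items():\n",
    "        args.extend(['--{}'.format(name), str(value)])\n",
    "    return args\n",
    "\n",
    "# A small dictionary in the same style as the hyper-parameters above\n",
    "print(' '.join(to_cli_args({'lr': 0.25, 'arch': 'fconv_iwslt_de_en'})))"
   ]
  },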
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are ready to define the Estimator, which will encapsulate all the required parameters needed for launching the training on Amazon SageMaker. \n",
    "\n",
    "For training, the FAIRSeq toolkit recommends to train on GPU instances, such as the `ml.p3` instance family [available in Amazon SageMaker](https://aws.amazon.com/sagemaker/pricing/instance-types/). In this example, we are training on 2 instances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.estimator import Estimator\n",
    "\n",
    "algorithm_name = \"pytorch-fairseq\"\n",
    "image = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account, region, algorithm_name)\n",
    "\n",
    "estimator = Estimator(image,\n",
    "                     role,\n",
    "                     train_instance_count=2,\n",
    "                     train_instance_type='ml.p3.8xlarge',\n",
    "                     train_volume_size=100, \n",
    "                     output_path='s3://{}/output'.format(bucket),\n",
    "                     hyperparameters=hyperparameters)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The call to fit will launch the training job and regularly report on the different performance metrics related to the training. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "estimator.fit(inputs=inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the model has finished training, we can go ahead and test its translation capabilities by deploying it on an endpoint.\n",
    "\n",
    "## Hosting the model\n",
    "\n",
    "We first need to define a base JSONPredictor class that will help us with sending predictions to the model once it's hosted on the Amazon SageMaker endpoint. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer\n",
    "\n",
    "class JSONPredictor(RealTimePredictor):\n",
    "    def __init__(self, endpoint_name, sagemaker_session):\n",
    "        super(JSONPredictor, self).__init__(endpoint_name, sagemaker_session, json_serializer, json_deserializer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now use the estimator object to deploy the model artificats (the trained model), and deploy it on a CPU instance as we no longer need a GPU instance for simply infering from the model. Let's use a `ml.m5.xlarge`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.m5.xlarge', predictor_cls=JSONPredictor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now it's your time to play. Input a sentence in German and get the translation in English by simply calling predict. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import html\n",
    "\n",
    "text_input = 'Guten Morgen'\n",
    "\n",
    "result = predictor.predict(text_input)\n",
    "#  Some characters are escaped HTML-style requiring to unescape them before printing\n",
    "print(html.unescape(result))"
   ]
  },
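  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the unescaping step does without calling the endpoint, the cell below applies `html.unescape` to a made-up string. The sample text is an assumption for illustration, not actual model output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import html\n",
    "\n",
    "# Made-up example string containing HTML entities (not real model output)\n",
    "sample = 'I don&apos;t know &amp; neither do you'\n",
    "\n",
    "# html.unescape converts entities such as &apos; and &amp; back to plain characters\n",
    "print(html.unescape(sample))"
   ]
  },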
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once you're done with getting predictions, remember to shut down your endpoint as you no longer need it. \n",
    "\n",
    "## Delete endpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sagemaker_session.delete_endpoint(predictor.endpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Voila! For more information, you can check out the [FAIRSeq toolkit homepage](https://github.com/pytorch/fairseq). "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p36",
   "language": "python",
   "name": "conda_pytorch_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
